package tensorflow

import (
	"path/filepath"
	"slices"
	"testing"

	"github.com/photoprism/photoprism/pkg/fs"
)

// Fixture locations shared by the tests in this file. Each relative path is
// passed through fs.Abs once, when the package variables are initialized
// (presumably yielding an absolute path; confirm against fs.Abs).
var (
	// assetsPath is where the tests look up bundled models ("models/...").
	assetsPath = fs.Abs("../../../assets")

	// testDataPath is where the tests look up the model fixtures of this package.
	testDataPath = fs.Abs("testdata")
)

func TestTF1ModelLoad(t *testing.T) {
	model, err := SavedModel(
		filepath.Join(assetsPath, "models", "nasnet"),
		[]string{"photoprism"})

	if err != nil {
		t.Fatal(err)
	}

	_, _, err = GetInputAndOutputFromSavedModel(model)
	if err == nil {
		t.Fatalf("TF1 does not have signatures, but GetInput worked")
	}

	input, output, err := GuessInputAndOutput(model)
	if err != nil {
		t.Fatal(err)
	}

	switch {
	case input == nil:
		t.Fatal("Could not get the input")
	case output == nil:
		t.Fatal("Could not get the output")
	case input.Shape == nil:
		t.Fatal("Could not get the shape")
	default:
		t.Logf("Shape: %v", input.Shape)
	}
}

func TestTF2ModelLoad(t *testing.T) {
	model, err := SavedModel(
		filepath.Join(testDataPath, "tf2"),
		[]string{"serve"})

	if err != nil {
		t.Fatal(err)
	}

	input, output, err := GetInputAndOutputFromSavedModel(model)
	if err != nil {
		t.Fatal(err)
	}

	switch {
	case input == nil:
		t.Fatal("Could not get the input")
	case output == nil:
		t.Fatal("Could not get the output")
	case input.Shape == nil:
		t.Fatal("Could not get the shape")
	case !slices.Equal(input.Shape, DefaultPhotoInputShape()):
		t.Fatalf("Invalid shape calculated. Expected BHWC, got %v", input.Shape)
	}
}

func TestTF2ModelBCHWLoad(t *testing.T) {
	model, err := SavedModel(
		filepath.Join(testDataPath, "tf2_bchw"),
		[]string{"serve"})

	if err != nil {
		t.Fatal(err)
	}

	input, output, err := GetInputAndOutputFromSavedModel(model)
	if err != nil {
		t.Fatal(err)
	}

	switch {
	case input == nil:
		t.Fatal("Could not get the input")
	case output == nil:
		t.Fatal("Could not get the output")
	case input.Shape == nil:
		t.Fatal("Could not get the shape")
	case !slices.Equal(input.Shape, []ShapeComponent{
		ShapeBatch, ShapeColor, ShapeHeight, ShapeWidth,
	}):
		t.Fatalf("Invalid shape calculated. Expected BCHW, got %v", input.Shape)
	}
}
